{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finetune XLNET-Bahasa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "This tutorial is available as an IPython notebook at [Malaya/finetune/xlnet](https://github.com/huseinzol05/Malaya/tree/master/finetune/xlnet).\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, I am going to show how to finetune pretrained XLNET-Bahasa using TensorFlow Estimator.\n",
    "\n",
    "TF-Estimator is a module from the TensorFlow team that suits very long training runs, because it saves checkpoints to `model_dir` and resumes from them automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3 install tensorflow==1.15 xlnet-tensorflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download pretrained model\n",
    "\n",
    "The pretrained checkpoints are listed at https://github.com/huseinzol05/Malaya/tree/master/pretrained-model/xlnet#download. In this example, we are going to use the BASE size. Uncomment the lines below to download the pretrained model and tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sp10m.cased.v9.model\t\t\txlnet-base-500k-20-10-2020.gz\r\n",
      "tf-estimator-text-classification.ipynb\txlnet-base_config.json\r\n",
      "xlnet-base\r\n"
     ]
    }
   ],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/xlnet-base-500k-20-10-2020.gz\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/preprocess/sp10m.cased.v9.model\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/xlnet/config/xlnet-base_config.json\n",
    "# !tar -zxf xlnet-base-500k-20-10-2020.gz\n",
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.ckpt-500000.data-00000-of-00001  model.ckpt-500000.meta\r\n",
      "model.ckpt-500000.index\t\t       xlnet-base_config.json\r\n"
     ]
    }
   ],
   "source": [
    "!ls xlnet-base"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper module [malaya/finetune/utils.py](https://github.com/huseinzol05/Malaya/blob/master/finetune/utils.py) helps us train the model on a single GPU or on multiple GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.insert(0, '../')\n",
    "import utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are just going to train on a very small Bahasa news sentiment dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Lebih-lebih lagi dengan  kemudahan internet da...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Positive</td>\n",
       "      <td>boleh memberi teguran kepada parti tetapi perl...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Adalah membingungkan mengapa masyarakat Cina b...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Positive</td>\n",
       "      <td>Kami menurunkan defisit daripada 6.7 peratus p...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Ini masalahnya. Bukan rakyat, tetapi sistem</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      label                                               text\n",
       "0  Negative  Lebih-lebih lagi dengan  kemudahan internet da...\n",
       "1  Positive  boleh memberi teguran kepada parti tetapi perl...\n",
       "2  Negative  Adalah membingungkan mengapa masyarakat Cina b...\n",
       "3  Positive  Kami menurunkan defisit daripada 6.7 peratus p...\n",
       "4  Negative        Ini masalahnya. Bukan rakyat, tetapi sistem"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('../sentiment-data-v2.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Negative', 'Positive']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = df['label'].values.tolist()\n",
    "texts = df['text'].values.tolist()\n",
    "unique_labels = sorted(list(set(labels)))\n",
    "unique_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/model_utils.py:295: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from xlnet import model_utils\n",
    "from xlnet import xlnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "from xlnet.prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('sp10m.cased.v9.model')\n",
    "\n",
    "SEG_ID_A = 0\n",
    "SEG_ID_B = 1\n",
    "SEG_ID_CLS = 2\n",
    "SEG_ID_SEP = 3\n",
    "SEG_ID_PAD = 4\n",
    "\n",
    "special_symbols = {\n",
    "    '<unk>': 0,\n",
    "    '<s>': 1,\n",
    "    '</s>': 2,\n",
    "    '<cls>': 3,\n",
    "    '<sep>': 4,\n",
    "    '<pad>': 5,\n",
    "    '<mask>': 6,\n",
    "    '<eod>': 7,\n",
    "    '<eop>': 8,\n",
    "}\n",
    "\n",
    "VOCAB_SIZE = 32000\n",
    "UNK_ID = special_symbols['<unk>']\n",
    "CLS_ID = special_symbols['<cls>']\n",
    "SEP_ID = special_symbols['<sep>']\n",
    "MASK_ID = special_symbols['<mask>']\n",
    "EOD_ID = special_symbols['<eod>']\n",
    "\n",
    "\n",
    "def tokenize_fn(text):\n",
    "    text = preprocess_text(text, lower = False)\n",
    "    return encode_ids(sp_model, text)\n",
    "\n",
    "\n",
    "def token_to_ids(text, maxlen = 512):\n",
    "    tokens_a = tokenize_fn(text)\n",
    "    if len(tokens_a) > maxlen - 2:\n",
    "        tokens_a = tokens_a[: (maxlen - 2)]\n",
    "    segment_id = [SEG_ID_A] * len(tokens_a)\n",
    "    tokens_a.append(SEP_ID)\n",
    "    tokens_a.append(CLS_ID)\n",
    "    segment_id.append(SEG_ID_A)\n",
    "    segment_id.append(SEG_ID_CLS)\n",
    "    input_mask = [0.0] * len(tokens_a)\n",
    "    assert len(tokens_a) == len(input_mask) == len(segment_id)\n",
    "    return {\n",
    "        'input_id': tokens_a,\n",
    "        'input_mask': input_mask,\n",
    "        'segment_id': segment_id,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. `input_id`, the integer IDs of the tokenized subwords, taken from the SentencePiece vocabulary. `<sep>` (`4`) and `<cls>` (`3`) are appended at the end.\n",
    "2. `input_mask`, the attention mask. Real tokens get `0`, and during batching shorter sequences are padded with `1`, so that the model does not learn padded positions as part of the context. https://github.com/zihangdai/xlnet/blob/master/classifier_utils.py#L113\n",
    "3. `segment_id`, used for text pair classification. In this single-text case, every token is simply `0`, and the final `<cls>` token gets `2` (`SEG_ID_CLS`)."
   ]
  },
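  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the three fields concrete, here is a minimal NumPy sketch of how two encoded examples of different lengths end up in one right-padded batch. `pad_batch` is a hypothetical helper written only for this illustration, and the token IDs are made up. The padding values (`0` for `input_id`, `1.0` for `input_mask`, `4` for `segment_id`) mirror the ones passed to `padded_batch` in the input pipeline of this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "SEG_ID_PAD = 4  # same value as SEG_ID_PAD defined above\n",
    "\n",
    "\n",
    "def pad_batch(examples):\n",
    "    # hypothetical helper, for illustration only\n",
    "    maxlen = max(len(e['input_id']) for e in examples)\n",
    "    batch = {'input_id': [], 'input_mask': [], 'segment_id': []}\n",
    "    for e in examples:\n",
    "        pad = maxlen - len(e['input_id'])\n",
    "        # padded positions: id 0, mask 1.0 (ignored), segment SEG_ID_PAD\n",
    "        batch['input_id'].append(e['input_id'] + [0] * pad)\n",
    "        batch['input_mask'].append(e['input_mask'] + [1.0] * pad)\n",
    "        batch['segment_id'].append(e['segment_id'] + [SEG_ID_PAD] * pad)\n",
    "    return {k: np.array(v) for k, v in batch.items()}\n",
    "\n",
    "\n",
    "# made-up IDs; each example ends with <sep> (4) and <cls> (3)\n",
    "examples = [\n",
    "    {'input_id': [11, 12, 4, 3], 'input_mask': [0.0] * 4, 'segment_id': [0, 0, 0, 2]},\n",
    "    {'input_id': [21, 4, 3], 'input_mask': [0.0] * 3, 'segment_id': [0, 0, 2]},\n",
    "]\n",
    "batch = pad_batch(examples)\n",
    "print(batch['input_mask'])\n",
    "```\n",
    "\n",
    "Only the shorter example receives padding, and its padded position is the only one whose mask is `1`."
   ]
  },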
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': [1620,\n",
       "  13,\n",
       "  5177,\n",
       "  53,\n",
       "  33,\n",
       "  2808,\n",
       "  3168,\n",
       "  24,\n",
       "  3400,\n",
       "  807,\n",
       "  21,\n",
       "  16179,\n",
       "  31,\n",
       "  742,\n",
       "  578,\n",
       "  17153,\n",
       "  9,\n",
       "  4,\n",
       "  3],\n",
       " 'input_mask': [0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0],\n",
       " 'segment_id': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "token_to_ids(texts[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TF-Estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TF-Estimator requires 2 parts:\n",
    "\n",
    "1. Input pipeline, https://www.tensorflow.org/api_docs/python/tf/data/Dataset\n",
    "2. Model definition, https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
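    "def generate():\n",
    "    # loop over the data forever; the number of training steps bounds the run\n",
    "    while True:\n",
    "        for i in range(len(texts)):\n",
    "            # skip very short texts (5 characters or fewer)\n",
    "            if len(texts[i]) > 5:\n",
    "                d = token_to_ids(texts[i])\n",
    "                # label is the index of the class name in unique_labels\n",
    "                d['label'] = [unique_labels.index(labels[i])]\n",
    "                # no-op here: token_to_ids does not return a 'tokens' key\n",
    "                d.pop('tokens', None)\n",
    "                yield d"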
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': [1620,\n",
       "  13,\n",
       "  5177,\n",
       "  53,\n",
       "  33,\n",
       "  2808,\n",
       "  3168,\n",
       "  24,\n",
       "  3400,\n",
       "  807,\n",
       "  21,\n",
       "  16179,\n",
       "  31,\n",
       "  742,\n",
       "  578,\n",
       "  17153,\n",
       "  9,\n",
       "  4,\n",
       "  3],\n",
       " 'input_mask': [0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0],\n",
       " 'segment_id': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
       " 'label': [0]}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g = generate()\n",
    "next(g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input pipeline must be a function that returns a function. The inner function takes no arguments and returns the dataset.\n",
    "\n",
    "```python\n",
    "def get_dataset(batch_size = 32, shuffle_size = 32):\n",
    "    def get():\n",
    "        return dataset\n",
    "    return get\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(batch_size = 32, shuffle_size = 32):\n",
    "    def get():\n",
    "        dataset = tf.data.Dataset.from_generator(\n",
    "            generate,\n",
    "            {'input_id': tf.int32, 'input_mask': tf.float32, 'segment_id': tf.int32, 'label': tf.int32},\n",
    "            output_shapes = {\n",
    "                'input_id': tf.TensorShape([None]),\n",
    "                'input_mask': tf.TensorShape([None]),\n",
    "                'segment_id': tf.TensorShape([None]),\n",
    "                'label': tf.TensorShape([None])\n",
    "            },\n",
    "        )\n",
    "        dataset = dataset.shuffle(shuffle_size)\n",
    "        dataset = dataset.padded_batch(\n",
    "            batch_size,\n",
    "            padded_shapes = {\n",
    "                'input_id': tf.TensorShape([None]),\n",
    "                'input_mask': tf.TensorShape([None]),\n",
    "                'segment_id': tf.TensorShape([None]),\n",
    "                'label': tf.TensorShape([None])\n",
    "            },\n",
    "            padding_values = {\n",
    "                'input_id': tf.constant(0, dtype = tf.int32),\n",
    "                'input_mask': tf.constant(1.0, dtype = tf.float32),\n",
    "                'segment_id': tf.constant(4, dtype = tf.int32),\n",
    "                'label': tf.constant(0, dtype = tf.int32),\n",
    "            },\n",
    "        )\n",
    "        return dataset\n",
    "    return get"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Test data pipeline using tf.Session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-17-2f00f4f10c26>:4: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "iterator = get_dataset()()\n",
    "iterator = iterator.make_one_shot_iterator().get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': <tf.Tensor 'IteratorGetNext:0' shape=(?, ?) dtype=int32>,\n",
       " 'input_mask': <tf.Tensor 'IteratorGetNext:1' shape=(?, ?) dtype=float32>,\n",
       " 'segment_id': <tf.Tensor 'IteratorGetNext:3' shape=(?, ?) dtype=int32>,\n",
       " 'label': <tf.Tensor 'IteratorGetNext:2' shape=(?, ?) dtype=int32>}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': array([[1084,  791,  835, ...,    0,    0,    0],\n",
       "        [ 256, 8993,    9, ...,    0,    0,    0],\n",
       "        [8110,   87, 1743, ...,    0,    0,    0],\n",
       "        ...,\n",
       "        [ 767,  250,   51, ...,    0,    0,    0],\n",
       "        [ 398, 8269,  742, ...,    9,    4,    3],\n",
       "        [3593,   21, 7901, ...,    0,    0,    0]], dtype=int32),\n",
       " 'input_mask': array([[0., 0., 0., ..., 1., 1., 1.],\n",
       "        [0., 0., 0., ..., 1., 1., 1.],\n",
       "        [0., 0., 0., ..., 1., 1., 1.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 1., 1.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 1., 1., 1.]], dtype=float32),\n",
       " 'segment_id': array([[0, 0, 0, ..., 4, 4, 4],\n",
       "        [0, 0, 0, ..., 4, 4, 4],\n",
       "        [0, 0, 0, ..., 4, 4, 4],\n",
       "        ...,\n",
       "        [0, 0, 0, ..., 4, 4, 4],\n",
       "        [0, 0, 0, ..., 0, 0, 2],\n",
       "        [0, 0, 0, ..., 4, 4, 4]], dtype=int32),\n",
       " 'label': array([[0],\n",
       "        [0],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1]], dtype=int32)}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess.run(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model definition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It must be a function that accepts 4 parameters.\n",
    "\n",
    "```python\n",
    "def model_fn(features, labels, mode, params):\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/xlnet.py:64: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "kwargs = dict(\n",
    "    is_training = True,\n",
    "    use_tpu = False,\n",
    "    use_bfloat16 = False,\n",
    "    dropout = 0.1,\n",
    "    dropatt = 0.1,\n",
    "    init = 'normal',\n",
    "    init_range = 0.1,\n",
    "    init_std = 0.05,\n",
    "    clamp_len = -1,\n",
    ")\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "xlnet_config = xlnet.XLNetConfig(json_path = 'xlnet-base_config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 10\n",
    "batch_size = 32\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = 10\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "training_parameters = dict(\n",
    "    decay_method = 'poly',\n",
    "    train_steps = num_train_steps,\n",
    "    learning_rate = learning_rate,\n",
    "    warmup_steps = num_warmup_steps,\n",
    "    min_lr_ratio = 0.0,\n",
    "    weight_decay = 0.00,\n",
    "    adam_epsilon = 1e-8,\n",
    "    num_core_per_host = 1,\n",
    "    lr_layer_decay_rate = 1,\n",
    "    use_tpu = False,\n",
    "    use_bfloat16 = False,\n",
    "    dropout = 0.0,\n",
    "    dropatt = 0.0,\n",
    "    init = 'normal',\n",
    "    init_range = 0.1,\n",
    "    init_std = 0.05,\n",
    "    clip = 1.0,\n",
    "    clamp_len = -1,\n",
    ")"
   ]
  },
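  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above hard-codes `num_train_steps = 10` to keep this demo short, and `epoch` is not used anywhere. For a full run, one common convention is to derive the step count from the dataset size. The sketch below assumes that convention; `num_examples` is a made-up placeholder for the number of training examples (for instance `len(texts)`), and `steps_per_epoch` is a name introduced only for this illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "epoch = 10\n",
    "batch_size = 32\n",
    "warmup_proportion = 0.1\n",
    "num_examples = 1000  # placeholder value, not the real dataset size\n",
    "\n",
    "# one step consumes one batch, so round up to cover the last partial batch\n",
    "steps_per_epoch = math.ceil(num_examples / batch_size)\n",
    "num_train_steps = steps_per_epoch * epoch\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "print(steps_per_epoch, num_train_steps, num_warmup_steps)\n",
    "```\n",
    "\n",
    "Because the generator in this notebook loops forever, the step count is what actually ends training."
   ]
  },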
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameter:\n",
    "    def __init__(\n",
    "        self,\n",
    "        decay_method,\n",
    "        warmup_steps,\n",
    "        weight_decay,\n",
    "        adam_epsilon,\n",
    "        num_core_per_host,\n",
    "        lr_layer_decay_rate,\n",
    "        use_tpu,\n",
    "        learning_rate,\n",
    "        train_steps,\n",
    "        min_lr_ratio,\n",
    "        clip,\n",
    "        **kwargs\n",
    "    ):\n",
    "        self.decay_method = decay_method\n",
    "        self.warmup_steps = warmup_steps\n",
    "        self.weight_decay = weight_decay\n",
    "        self.adam_epsilon = adam_epsilon\n",
    "        self.num_core_per_host = num_core_per_host\n",
    "        self.lr_layer_decay_rate = lr_layer_decay_rate\n",
    "        self.use_tpu = use_tpu\n",
    "        self.learning_rate = learning_rate\n",
    "        self.train_steps = train_steps\n",
    "        self.min_lr_ratio = min_lr_ratio\n",
    "        self.clip = clip\n",
    "\n",
    "\n",
    "training_parameters = Parameter(**training_parameters)\n",
    "init_checkpoint = 'xlnet-base/model.ckpt-500000'"
   ]
  },
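  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feeling for what `warmup_steps`, `train_steps` and `min_lr_ratio` do with `decay_method = 'poly'`, here is a plain-Python sketch. It is an assumption about the shape of the schedule (linear warmup, then linear decay down to `learning_rate * min_lr_ratio`), not a copy of the code in `model_utils.get_train_op`, and `lr_at` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def lr_at(step, learning_rate = 2e-5, warmup_steps = 1, train_steps = 10, min_lr_ratio = 0.0):\n",
    "    # hypothetical helper: assumed warmup + linear decay schedule\n",
    "    if step < warmup_steps:\n",
    "        # linear warmup from 0 up to the peak learning rate\n",
    "        return learning_rate * step / warmup_steps\n",
    "    end_lr = learning_rate * min_lr_ratio\n",
    "    decay_steps = train_steps - warmup_steps\n",
    "    # fraction of the decay phase already completed, capped at 1\n",
    "    progress = min(step - warmup_steps, decay_steps) / decay_steps\n",
    "    return (learning_rate - end_lr) * (1 - progress) + end_lr\n",
    "\n",
    "\n",
    "# defaults match the values in this notebook: 10 train steps, 1 warmup step\n",
    "schedule = [lr_at(step) for step in range(11)]\n",
    "for step, lr in enumerate(schedule):\n",
    "    print(step, lr)\n",
    "```\n",
    "\n",
    "With only 10 training steps the learning rate peaks at step 1 and then falls towards 0, which is another reason to treat this run as a smoke test rather than a real finetuning."
   ]
  },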
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn(features, labels, mode, params):\n",
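    "    # the label travels inside `features`; the `labels` argument is not used\n",
    "    Y = tf.cast(features['label'][:, 0], tf.int32)\n",
    "\n",
    "    # XLNet expects time-major inputs, so transpose [batch, seq_len] to [seq_len, batch]\n",
    "    xlnet_model = xlnet.XLNetModel(\n",
    "        xlnet_config = xlnet_config,\n",
    "        run_config = xlnet_parameters,\n",
    "        input_ids = tf.transpose(features['input_id'], [1, 0]),\n",
    "        seg_ids = tf.transpose(features['segment_id'], [1, 0]),\n",
    "        input_mask = tf.transpose(features['input_mask'], [1, 0]),\n",
    "    )\n",
    "\n",
    "    # sequence output is [seq_len, batch, d_model]; transpose back to batch-major\n",
    "    output_layer = xlnet_model.get_sequence_output()\n",
    "    output_layer = tf.transpose(output_layer, [1, 0, 2])\n",
    "\n",
    "    # project every position to 2 classes, then classify from sequence position 0\n",
    "    logits_seq = tf.layers.dense(output_layer, 2)\n",
    "    logits = logits_seq[:, 0]\n",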
    "\n",
    "    loss = tf.reduce_mean(\n",
    "        tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits = logits, labels = Y\n",
    "        )\n",
    "    )\n",
    "\n",
    "    tf.identity(loss, 'train_loss')\n",
    "\n",
    "    accuracy = tf.metrics.accuracy(\n",
    "        labels = Y, predictions = tf.argmax(logits, axis = 1)\n",
    "    )\n",
    "    tf.identity(accuracy[1], name = 'train_accuracy')\n",
    "\n",
    "    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n",
    "\n",
    "    assignment_map, initialized_variable_names = utils.get_assignment_map_from_checkpoint(\n",
    "        variables, init_checkpoint\n",
    "    )\n",
    "\n",
    "    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)\n",
    "\n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        train_op, _, _ = model_utils.get_train_op(training_parameters, loss)\n",
    "        estimator_spec = tf.estimator.EstimatorSpec(\n",
    "            mode = mode, loss = loss, train_op = train_op\n",
    "        )\n",
    "\n",
    "    elif mode == tf.estimator.ModeKeys.EVAL:\n",
    "        estimator_spec = tf.estimator.EstimatorSpec(\n",
    "            mode = tf.estimator.ModeKeys.EVAL,\n",
    "            loss = loss,\n",
    "            eval_metric_ops = {'accuracy': accuracy},\n",
    "        )\n",
    "\n",
    "    return estimator_spec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initiate training session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = get_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From ../utils.py:62: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n",
      "\n",
      "WARNING:tensorflow:From ../utils.py:62: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.\n",
      "\n",
      "INFO:tensorflow:Using config: {'_model_dir': 'finetuned-xlnet-base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 10, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 1, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f31fb236fd0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/xlnet.py:221: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/xlnet.py:221: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/modeling.py:453: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/modeling.py:460: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/modeling.py:535: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/modeling.py:67: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "INFO:tensorflow:**** Trainable Variables ****\n",
      "INFO:tensorflow:  name = model/transformer/r_w_bias:0, shape = (12, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/r_r_bias:0, shape = (12, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/word_embedding/lookup_table:0, shape = (32000, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/r_s_bias:0, shape = (12, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/seg_embed:0, shape = (12, 2, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_0/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_1/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_2/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_3/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_4/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_5/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_6/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_7/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_8/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_9/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_10/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_11/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = dense/kernel:0, shape = (768, 2)\n",
      "INFO:tensorflow:  name = dense/bias:0, shape = (2,)\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/model_utils.py:96: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/model_utils.py:108: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/model_utils.py:123: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/xlnet/model_utils.py:131: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 0 into finetuned-xlnet-base/model.ckpt.\n",
      "INFO:tensorflow:train_accuracy = 0.5, train_loss = 0.8626036\n",
      "INFO:tensorflow:loss = 0.8626036, step = 1\n"
     ]
    }
   ],
   "source": [
    "train_hooks = [\n",
    "    tf.train.LoggingTensorHook(\n",
    "        ['train_accuracy', 'train_loss'], every_n_iter = 1\n",
    "    )\n",
    "]\n",
    "utils.run_training(\n",
    "    train_fn = train_dataset,\n",
    "    model_fn = model_fn,\n",
    "    model_dir = 'finetuned-xlnet-base',\n",
    "    num_gpus = 1,\n",
    "    log_step = 1,\n",
    "    save_checkpoint_step = epoch,\n",
    "    max_steps = epoch,\n",
    "    train_hooks = train_hooks,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
